{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/dmlc/gluon-cv/blob/onnx/scripts/onnx/notebooks/pose/mobile_pose_mobilenetv2_1.0.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic so the package is installed into the running kernel's environment.\n",
    "%pip install --upgrade onnxruntime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import numpy as np\n",
    "import onnxruntime as rt\n",
    "import urllib.request\n",
    "import os.path\n",
    "from PIL import Image\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def fetch_detect_model():\n",
    "    if not os.path.isfile(\"yolo3_mobilenet1.0_coco.onnx\"):\n",
    "        urllib.request.urlretrieve(\"https://apache-mxnet.s3-us-west-2.amazonaws.com/onnx/models/gluoncv-yolo3_mobilenet1.0_coco-115299e3.onnx\", filename=\"yolo3_mobilenet1.0_coco.onnx\")\n",
    "    return 'yolo3_mobilenet1.0_coco.onnx'\n",
    "    \n",
    "def fetch_model():\n",
    "    if not os.path.isfile(\"mobile_pose_mobilenetv2_1.0.onnx\"):\n",
    "        urllib.request.urlretrieve(\"https://apache-mxnet.s3-us-west-2.amazonaws.com/onnx/models/gluoncv-mobile_pose_mobilenetv2_1.0-2d772d7c.onnx\", filename=\"mobile_pose_mobilenetv2_1.0.onnx\")\n",
    "    return \"mobile_pose_mobilenetv2_1.0.onnx\"\n",
    "\n",
    "def prepare_img(img_path, input_shape):\n",
    "    \"\"\"Load an image and resize it to the spatial size given by ``input_shape``.\n",
    "\n",
    "    Args:\n",
    "        img_path: path of the image file to open.\n",
    "        input_shape: shape in BHWC order; only the height (index 1) and\n",
    "            width (index 2) entries are used.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(plt_img, img)`` where ``plt_img`` is the resized RGB image as\n",
    "        an array and ``img`` is the same data with a leading batch axis added,\n",
    "        cast to float32.\n",
    "    \"\"\"\n",
    "    height, width = input_shape[1], input_shape[2]\n",
    "    # Close the underlying file deterministically; the converted and resized\n",
    "    # copy is held in memory and no longer depends on it.\n",
    "    with Image.open(img_path) as source:\n",
    "        resized = source.convert('RGB').resize((width, height))\n",
    "    plt_img = np.asarray(resized)\n",
    "    img = np.expand_dims(plt_img, axis=0).astype('float32')\n",
    "\n",
    "    return plt_img, img\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Make sure to replace `img_path` below with the path to the image you want to use.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Fetch both ONNX files (each helper downloads only when the file is missing)\n",
    "# and keep the returned local file names.\n",
    "detector = fetch_detect_model()\n",
    "model = fetch_model()\n",
    "# Placeholder -- set this to the path of your own image before running.\n",
    "img_path = 'Your image'\n",
    "# Resize the image to 512x512 for the detector; the shape is given in BHWC order.\n",
    "plt_img, img = prepare_img(img_path, input_shape=(1, 512, 512, 3))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Open the detector with ONNX Runtime (no custom session options) and look up\n",
    "# the name of its first input, which keys the feed dict passed to run().\n",
    "detector_session = rt.InferenceSession(detector, sess_options=None)\n",
    "detector_input_name = detector_session.get_inputs()[0].name\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Run the detector on the prepared image; the three results are unpacked in\n",
    "# the order the session returns them.\n",
    "detector_outputs = detector_session.run([], {detector_input_name: img})\n",
    "class_IDs, scores, bounding_boxs = detector_outputs\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "(Optional) We use mxnet and gluoncv to process the detector result.\n",
    "\n",
    "Feel free to process the result your own way\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic so the packages are installed into the running kernel's environment.\n",
    "%pip install --upgrade mxnet gluoncv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import mxnet as mx\n",
    "from gluoncv.data.transforms.pose import detector_to_simple_pose\n",
    "\n",
    "# Wrap the detector outputs as MXNet NDArrays for the gluoncv helper below.\n",
    "class_IDs = mx.nd.array(class_IDs)\n",
    "scores = mx.nd.array(scores)\n",
    "bounding_boxs = mx.nd.array(bounding_boxs)\n",
    "# Exported model has preprocess layer, so we pass mean=0 and std=1 to disable preprocess\n",
    "# when converting to pose input\n",
    "pose_input, upscale_bbox = detector_to_simple_pose(plt_img, class_IDs, scores, bounding_boxs,\n",
    "                                                   mean=(0,0,0), std=(1,1,1), output_shape=(256, 192))\n",
    "# NOTE(review): gluoncv appears to return pose_input=None with an empty box list\n",
    "# when no person passes the detection threshold -- confirm against gluoncv.\n",
    "# Fail with a clear message instead of an obscure error from np.transpose below.\n",
    "if pose_input is None or len(upscale_bbox) == 0:\n",
    "    raise ValueError(\"No person was detected in the image; cannot build the pose input.\")\n",
    "# Move axis 1 to the last position (presumably NCHW -> NHWC; verify against the exported model).\n",
    "pose_input = np.transpose(pose_input, (0,2,3,1))\n",
    "# Rescale by 255 -- presumably the crops come back in [0, 1]; TODO confirm.\n",
    "pose_input *= 255\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Open the pose model with ONNX Runtime (no custom session options) and look\n",
    "# up the name of its first input, which keys the feed dict passed to run().\n",
    "pose_session = rt.InferenceSession(model, sess_options=None)\n",
    "pose_input_name = pose_session.get_inputs()[0].name\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def _pose_heatmap(person_crop):\n",
    "    \"\"\"Run the pose model on one crop (with a batch axis prepended) and return its first output.\"\"\"\n",
    "    batch = person_crop.expand_dims(0).asnumpy()\n",
    "    return pose_session.run([], {pose_input_name: batch})[0]\n",
    "\n",
    "# Process the crops one at a time and join the per-crop results along axis 0.\n",
    "predicted_heatmap = np.concatenate([_pose_heatmap(crop) for crop in pose_input])\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "(Optional) We use mxnet and gluoncv to visualize the result.\n",
    "\n",
    "Feel free to visualize the result your own way\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import mxnet as mx\n",
    "from gluoncv import utils\n",
    "from matplotlib import pyplot as plt\n",
    "from gluoncv.data.transforms.pose import heatmap_to_coord\n",
    "\n",
    "# Wrap the heatmaps as an MXNet NDArray for the gluoncv helpers below.\n",
    "predicted_heatmap = mx.nd.array(predicted_heatmap)\n",
    "# Convert the heatmaps, together with the upscaled boxes, into keypoint\n",
    "# coordinates and per-keypoint confidences.\n",
    "pred_coords, confidence = heatmap_to_coord(predicted_heatmap, upscale_bbox)\n",
    "# Draw keypoints and detection boxes on the resized image; box_thresh and\n",
    "# keypoint_thresh are passed through to plot_keypoints.\n",
    "ax = utils.viz.plot_keypoints(plt_img, pred_coords, confidence,\n",
    "                              class_IDs, bounding_boxs, scores,\n",
    "                              box_thresh=0.5, keypoint_thresh=0.2)\n",
    "plt.show()\n",
    "    "
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
